import pickle
import jax
import matplotlib.pyplot as plt
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from scipy.stats import gaussian_kde
import plotly.express as px
import pandas as pd
import pickle
# Short alias for the distributions module of the JAX-substrate
# TensorFlow Probability package imported above as `tfp`.
tfd = tfp.distributions
import plotly
# from laplax import ADLaplace
# Set up plotly's offline notebook mode once at import time, before
# plot_all() calls fig.show().
plotly.offline.init_notebook_mode()
def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only point this at result files produced by trusted runs.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _kde_pdfs(samples, x, n_params=2):
    """Return one Gaussian-KDE density on *x* per column ``samples[:, i]``."""
    return [gaussian_kde(samples[:, i])(x) for i in range(n_params)]


def _normal_pdfs(mean, std, x, n_params=2):
    """Return the ``Normal(mean[i], std[i])`` density on *x* for each i."""
    return [
        tfd.Normal(loc=mean[i], scale=std[i]).prob(x)
        for i in range(n_params)
    ]


def plot_all(varient=''):
    """Plot the theta0/theta1 posterior densities from every saved method.

    Loads four pickled results from ``./results_data`` and draws all eight
    density curves on a single plotly line chart:

    * ``linear_regression_Ajax<varient>``: an object with a
      ``sample(seed=..., sample_shape=...)`` method; 10000 samples are drawn
      and the ``"theta"`` columns are smoothed with a Gaussian KDE.
    * ``linear_regression_laplace<varient>``: a mapping with ``'mean'`` and
      ``'cov'``; each marginal is a Normal with the square root of the
      covariance diagonal as its scale.
    * ``MCMC_Blackjax<varient>``: an object whose ``position['theta']``
      columns are smoothed with a Gaussian KDE.
    * ``linear_regression_true_posterior``: a mapping with ``"mean"`` and
      ``"covariance"``. This path is loaded without the *varient* suffix.

    Args:
        varient: Suffix appended to the first three file names. It is passed
            through ``str()``, so non-string values such as ints work.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    suffix = str(varient)
    # Evaluation grid shared by every curve.
    x = jnp.linspace(0, 6, 10000)

    pdfs = []
    labels = []

    # Variational inference: KDE over samples drawn from the fitted posterior.
    posterior = _load_pickle('./results_data/linear_regression_Ajax' + suffix)
    vi_samples = posterior.sample(
        seed=jax.random.PRNGKey(10), sample_shape=(10000,)
    )
    pdfs += _kde_pdfs(vi_samples["theta"], x)
    labels += ['Ajax VI theta0', 'Ajax VI theta1']

    # Laplace approximation: Gaussian marginals from mean and cov diagonal.
    laplace = _load_pickle('./results_data/linear_regression_laplace' + suffix)
    laplace_std = jnp.sqrt(jnp.diag(laplace['cov']))
    pdfs += _normal_pdfs(laplace['mean'], laplace_std, x)
    labels += ['Laplace approximation theta0', 'Laplace approximation theta1']

    # MCMC: KDE over the stored chain positions.
    black_samples = _load_pickle('./results_data/MCMC_Blackjax' + suffix)
    pdfs += _kde_pdfs(black_samples.position['theta'], x)
    labels += ['Blackjax rmh theta0', 'Blackjax rmh theta1']

    # Reference posterior: Gaussian marginals; file name carries no suffix.
    true_params = _load_pickle(
        "./results_data/linear_regression_true_posterior"
    )
    true_cov = true_params["covariance"]
    true_std = [jnp.sqrt(true_cov[i, i]) for i in range(2)]
    pdfs += _normal_pdfs(true_params["mean"], true_std, x)
    labels += ["true theta0", "true theta1"]

    # Long-format table: the curves are laid end to end, so the grid is tiled
    # once per curve and each label is repeated once per grid point.
    n_points = x.shape[0]
    df = pd.DataFrame({
        "theta": jnp.tile(x, len(labels)),
        "PDF": jnp.array(pdfs).reshape((-1)),
        "label": [label for label in labels for _ in range(n_points)],
    })
    fig = px.line(
        df, "theta", "PDF", color="label",
        title="Linear regression posterior",
    )
    fig.show()
# Render the comparison figure using the default (empty) file-name suffix.
plot_all()
!jupyter nbconvert --to HTML linear_regression_results.ipynb
[NbConvertApp] Converting notebook linear_regression_results.ipynb to HTML
[NbConvertApp] Writing 6555574 bytes to linear_regression_results.html